// SPDX-License-Identifier: GPL-2.0-or-later
/*
 * Netlink parser library module
 * Copyright (c) 2024 Cisco Systems Inc.
 */

#include "zebra.h"

/* Only used with netlink, for now */
#ifdef HAVE_NETLINK

#include "lib/netlink_parser.h"
#include "lib/zlog.h"

/*
 * Lay out an empty netlink message header at the start of @buf.
 *
 * The header is zeroed, filled in from the arguments, and its length is
 * set to cover the header alone (no payload yet).  Only the first
 * sizeof(struct nlmsghdr) bytes of @buf are written.
 *
 * Returns the header (aliasing @buf), or NULL when @buflen is too small
 * to hold a struct nlmsghdr.
 */
struct nlmsghdr *nl_msg_init(int type, int flags, int seq, int pid,
			     uint8_t *buf, uint32_t buflen)
{
	struct nlmsghdr *hdr = (struct nlmsghdr *)buf;

	if (buflen >= sizeof(*hdr)) {
		memset(hdr, 0, sizeof(*hdr));
		hdr->nlmsg_type = type;
		hdr->nlmsg_flags = flags;
		hdr->nlmsg_seq = seq;
		hdr->nlmsg_pid = pid;
		hdr->nlmsg_len = NLMSG_LENGTH(0);
		return hdr;
	}

	return NULL;
}

/*
 * Copy the length, type and flags out of netlink message header @n.
 * When @n is NULL nothing is written and the outputs keep their values.
 * The output pointers themselves are not checked and must be valid.
 */
void nl_msg_get_data(const struct nlmsghdr *n, uint16_t *ptype, uint32_t *plen,
		     uint16_t *pflags)
{
	if (n == NULL)
		return;

	*plen = n->nlmsg_len;
	*ptype = n->nlmsg_type;
	*pflags = n->nlmsg_flags;
}

/*
 * Index a run of route attributes by type.
 *
 * @tb must have room for @max + 1 entries.  It is cleared first, then
 * tb[type] is set to the last attribute seen with that type.  Attributes
 * whose (masked) type exceeds @max are skipped.  @len is the number of
 * bytes available starting at @rta.
 */
void netlink_parse_rtattr(struct rtattr **tb, int max, struct rtattr *rta, int len)
{
	struct rtattr *cur;

	memset(tb, 0, sizeof(struct rtattr *) * (max + 1));

	for (cur = rta; RTA_OK(cur, len); cur = RTA_NEXT(cur, len)) {
		/*
		 * Strip flag bits (e.g. NLA_F_NESTED) from the type before
		 * using it as an index; a flagged type would otherwise
		 * index far past the end of @tb when newer kernels send
		 * attributes we do not expect.
		 */
		uint16_t idx = cur->rta_type & NLA_TYPE_MASK;

		if (idx > max)
			continue;
		tb[idx] = cur;
	}
}

/*
 * Given a message header, locate and parse out the attributes.
 *
 * @tb/@max are passed to netlink_parse_rtattr(): @tb needs @max + 1 slots.
 * @buf holds a netlink message, and @len is the number of readable bytes in
 * @buf.  For message types with a known family header (currently
 * RTM_NEWROUTE/RTM_DELROUTE, which carry a struct rtmsg), the attributes are
 * taken to start after that header.  Otherwise they start right after
 * the netlink header.
 *
 * Returns the message header.  Returns NULL when @buf is NULL, when @len is
 * too short for the header or for the message it announces, or when the
 * message is too short for its family header.  On a NULL return @tb is not
 * modified.
 */
struct nlmsghdr *netlink_parse_buf(struct rtattr **tb, int max, void *buf,
				   size_t len)
{
	struct nlmsghdr *msg = buf;
	uint32_t offset = 0;
	uint8_t *ptr;

	/* The header itself must be readable before nlmsg_len is trusted */
	if (!msg || len < NLMSG_HDRLEN || len < msg->nlmsg_len)
		return NULL;

	/* Figure out where the attrs start, then use the attribute
	 * parse api.
	 */
	if (msg->nlmsg_type == RTM_NEWROUTE || msg->nlmsg_type == RTM_DELROUTE)
		offset = NLMSG_ALIGN(sizeof(struct rtmsg));

	/* Add clauses for the messages we want to use in this path... */

	/*
	 * nlmsg_len counts the netlink header as well, so the family header
	 * must fit after it.  This also rejects nlmsg_len < NLMSG_HDRLEN.
	 */
	if (NLMSG_LENGTH(offset) > msg->nlmsg_len)
		return NULL;

	ptr = (uint8_t *)NLMSG_DATA(msg) + offset;

	/* Only the bytes after both headers belong to the attributes */
	netlink_parse_rtattr(tb, max, (struct rtattr *)ptr,
			     msg->nlmsg_len - NLMSG_LENGTH(offset));

	return msg;
}

/**
 * netlink_parse_rtattr_nested() - Index the attributes nested inside @rta.
 * @tb:  Array of @max + 1 slots that receives the nested attributes.
 * @max: Highest attribute type to record.
 * @rta: Container attribute whose payload holds the nested attributes.
 *
 * Applies netlink_parse_rtattr() to the payload of @rta.
 */
void netlink_parse_rtattr_nested(struct rtattr **tb, int max, struct rtattr *rta)
{
	struct rtattr *first = RTA_DATA(rta);
	int payload_len = RTA_PAYLOAD(rta);

	netlink_parse_rtattr(tb, max, first, payload_len);
}

/*
 * Append @len raw bytes from @data to the end of message @n.  The bytes
 * are zero-padded up to the next netlink alignment boundary.
 *
 * @maxlen is the size of the buffer holding @n.  @data may be NULL only
 * when @len is 0.  Returns false (and logs) without modifying @n if the
 * aligned result would exceed @maxlen.
 */
bool nl_addraw_l(struct nlmsghdr *n, unsigned int maxlen, const void *data,
		 unsigned int len)
{
	unsigned int cur = NLMSG_ALIGN(n->nlmsg_len);

	/*
	 * Reject an oversized @len up front and compare by subtraction:
	 * NLMSG_ALIGN(len) + cur could otherwise wrap around for a huge
	 * @len and slip past the bound check.
	 */
	if (len > maxlen || cur > maxlen || NLMSG_ALIGN(len) > maxlen - cur) {
		zlog_err("ERROR message exceeded bound of %u", maxlen);
		return false;
	}

	/* memcpy() with a NULL source is undefined even for 0 bytes */
	if (len)
		memcpy(NLMSG_TAIL(n), data, len);
	memset((uint8_t *)NLMSG_TAIL(n) + len, 0, NLMSG_ALIGN(len) - len);
	n->nlmsg_len = cur + NLMSG_ALIGN(len);

	return true;
}

/*
 * Append a route attribute of @type carrying @alen bytes from @data to
 * the end of message @n.
 *
 * @maxlen is the size of the buffer holding @n.  @data may be NULL only
 * when @alen is 0, for example for an empty nest container.  The alignment
 * padding after the payload is zeroed so stale buffer bytes are not sent.
 *
 * Returns false, leaving @n untouched, when the attribute does not fit in
 * @maxlen or its length cannot be represented in the 16-bit rta_len field.
 */
bool nl_attr_put(struct nlmsghdr *n, unsigned int maxlen, int type, const void *data,
		 unsigned int alen)
{
	unsigned int len;
	struct rtattr *rta;

	/* rta_len is 16 bits: a larger payload would be silently truncated */
	if (alen > UINT16_MAX - RTA_LENGTH(0))
		return false;

	len = RTA_LENGTH(alen);

	if (NLMSG_ALIGN(n->nlmsg_len) + RTA_ALIGN(len) > maxlen)
		return false;

	rta = (struct rtattr *)(((char *)n) + NLMSG_ALIGN(n->nlmsg_len));
	rta->rta_type = type;
	rta->rta_len = len;

	if (data)
		memcpy(RTA_DATA(rta), data, alen);
	else
		assert(alen == 0);

	/* Zero the alignment padding that follows the payload */
	memset((uint8_t *)RTA_DATA(rta) + alen, 0, RTA_ALIGN(len) - len);

	n->nlmsg_len = NLMSG_ALIGN(n->nlmsg_len) + RTA_ALIGN(len);

	return true;
}

/*
 * Fixed-width convenience wrappers around nl_attr_put().  Each appends an
 * attribute of @type to @n whose payload is the bytes of @data, in host
 * byte order.  They return whatever nl_attr_put() returns.
 */
bool nl_attr_put8(struct nlmsghdr *n, unsigned int maxlen, int type, uint8_t data)
{
	const uint8_t val = data;

	return nl_attr_put(n, maxlen, type, &val, sizeof(val));
}

bool nl_attr_put16(struct nlmsghdr *n, unsigned int maxlen, int type, uint16_t data)
{
	const uint16_t val = data;

	return nl_attr_put(n, maxlen, type, &val, sizeof(val));
}

bool nl_attr_put32(struct nlmsghdr *n, unsigned int maxlen, int type, uint32_t data)
{
	const uint32_t val = data;

	return nl_attr_put(n, maxlen, type, &val, sizeof(val));
}

bool nl_attr_put64(struct nlmsghdr *n, unsigned int maxlen, int type, uint64_t data)
{
	const uint64_t val = data;

	return nl_attr_put(n, maxlen, type, &val, sizeof(val));
}

/*
 * Open a nested attribute of @type at the current tail of @n.  An empty
 * attribute is appended via nl_attr_put() and flagged NLA_F_NESTED.
 *
 * Returns the container attribute, or NULL if it does not fit within
 * @maxlen.  nl_attr_nest_end() can later fix up its length to span
 * whatever is appended after it.
 */
struct rtattr *nl_attr_nest(struct nlmsghdr *n, unsigned int maxlen, int type)
{
	/* Capture the tail before nl_attr_put() advances nlmsg_len */
	struct rtattr *container = NLMSG_TAIL(n);
	bool added = nl_attr_put(n, maxlen, type, NULL, 0);

	if (added) {
		container->rta_type |= NLA_F_NESTED;
		return container;
	}

	return NULL;
}

/*
 * Close a nested attribute: set @nest's length to span from @nest to the
 * current tail of @n, i.e. everything appended since it was opened.
 * Returns the current message length of @n.
 */
int nl_attr_nest_end(struct nlmsghdr *n, struct rtattr *nest)
{
	uint8_t *tail = (uint8_t *)NLMSG_TAIL(n);
	uint8_t *start = (uint8_t *)nest;

	nest->rta_len = tail - start;

	return n->nlmsg_len;
}

/*
 * Reserve a zeroed struct rtnexthop at the tail of @n and extend the
 * message length over it.
 *
 * Returns the reserved entry, or NULL (with @n unchanged) if it would not
 * fit within @maxlen.  nl_attr_rtnh_end() can later set its rtnh_len to
 * cover any attributes appended after it.
 */
struct rtnexthop *nl_attr_rtnh(struct nlmsghdr *n, unsigned int maxlen)
{
	struct rtnexthop *entry = (struct rtnexthop *)NLMSG_TAIL(n);
	unsigned int base = NLMSG_ALIGN(n->nlmsg_len);

	if (base + RTNH_ALIGN(sizeof(struct rtnexthop)) > maxlen)
		return NULL;

	memset(entry, 0, sizeof(struct rtnexthop));
	n->nlmsg_len = base + RTA_ALIGN(sizeof(struct rtnexthop));

	return entry;
}

/*
 * Finish a nexthop entry: set rtnh_len so that @rtnh spans from its start
 * to the current tail of @n, covering any attributes appended after it.
 */
void nl_attr_rtnh_end(struct nlmsghdr *n, struct rtnexthop *rtnh)
{
	const uint8_t *end = (const uint8_t *)NLMSG_TAIL(n);

	rtnh->rtnh_len = end - (const uint8_t *)rtnh;
}

/*
 * Append a sub-attribute of @type carrying @alen bytes from @data inside
 * attribute @rta, and grow @rta's length to cover it.
 *
 * @maxlen bounds the resulting aligned length of @rta.  @data may be NULL
 * only when @alen is 0.
 *
 * Returns false (and logs) without touching @rta in these cases: @alen is
 * negative, the result would exceed @maxlen, or the result cannot be
 * represented in the 16-bit rta_len field.
 */
bool nl_rta_put(struct rtattr *rta, unsigned int maxlen, int type,
		const void *data, int alen)
{
	struct rtattr *subrta;
	unsigned int len, total;

	/*
	 * A negative @alen would become a huge memcpy() size.  rta_len is
	 * 16 bits, so larger lengths would be silently truncated.
	 */
	if (alen < 0 || (unsigned int)alen > UINT16_MAX - RTA_LENGTH(0))
		goto too_big;

	len = RTA_LENGTH(alen);
	total = RTA_ALIGN(rta->rta_len) + RTA_ALIGN(len);
	if (total > maxlen || total > UINT16_MAX)
		goto too_big;

	subrta = (struct rtattr *)(((char *)rta) + RTA_ALIGN(rta->rta_len));
	subrta->rta_type = type;
	subrta->rta_len = len;
	if (alen)
		memcpy(RTA_DATA(subrta), data, alen);
	rta->rta_len = total;

	return true;

too_big:
	zlog_err("ERROR max allowed bound %u exceeded for rtattr", maxlen);
	return false;
}

/*
 * Fixed-width wrappers around nl_rta_put().  Each appends a sub-attribute
 * of @type inside @rta whose payload is the bytes of @data, in host byte
 * order.  They return whatever nl_rta_put() returns.
 */
bool nl_rta_put16(struct rtattr *rta, unsigned int maxlen, int type, uint16_t data)
{
	const uint16_t val = data;

	return nl_rta_put(rta, maxlen, type, &val, sizeof(val));
}

bool nl_rta_put64(struct rtattr *rta, unsigned int maxlen, int type, uint64_t data)
{
	const uint64_t val = data;

	return nl_rta_put(rta, maxlen, type, &val, sizeof(val));
}

/*
 * Open a nested attribute of @type inside attribute @rta.  An empty
 * sub-attribute is appended via nl_rta_put() and flagged NLA_F_NESTED.
 *
 * Returns the new container attribute, or NULL if it does not fit within
 * @maxlen.
 */
struct rtattr *nl_rta_nest(struct rtattr *rta, unsigned int maxlen, int type)
{
	/* Capture the tail before nl_rta_put() grows rta_len */
	struct rtattr *nest = RTA_TAIL(rta);

	/* nl_rta_put() returns true on success; bail out only on failure */
	if (!nl_rta_put(rta, maxlen, type, NULL, 0))
		return NULL;

	nest->rta_type |= NLA_F_NESTED;

	return nest;
}

/*
 * Close a nested attribute inside @rta: set @nest's length to span from
 * @nest to the current tail of @rta.  Returns @rta's current length.
 */
int nl_rta_nest_end(struct rtattr *rta, struct rtattr *nest)
{
	const uint8_t *tail = (const uint8_t *)RTA_TAIL(rta);

	nest->rta_len = tail - (const uint8_t *)nest;
	return rta->rta_len;
}

/*
 * struct nlattr counterparts of RTA_OK()/RTA_NEXT().
 * NLA_OK() checks that at least a full nlattr header remains in @len and
 * that the attribute's declared length is sane and fits in @len.
 * NLA_NEXT() subtracts the aligned attribute length from @attrlen and
 * yields a pointer to the following attribute.
 */
#define NLA_OK(nla, len)                                                       \
	((len) >= (int)sizeof(struct nlattr)                                   \
	 && (nla)->nla_len >= sizeof(struct nlattr)                            \
	 && (nla)->nla_len <= (len))
#define NLA_NEXT(nla, attrlen)                                                 \
	((attrlen) -= NLA_ALIGN((nla)->nla_len),                               \
	 (struct nlattr *)(((char *)(nla)) + NLA_ALIGN((nla)->nla_len)))

/*
 * Index a run of struct nlattr by type, mirroring netlink_parse_rtattr().
 *
 * @tb must have room for @max + 1 entries.  It is cleared first, so slots
 * for absent attributes are NULL.  tb[type] is then set to the last
 * attribute seen with that type, and types above @max are skipped.  @len
 * is the number of bytes available starting at @nla.
 */
void netlink_parse_nlattr(struct nlattr **tb, int max, struct nlattr *nla,
			  int len)
{
	memset(tb, 0, sizeof(struct nlattr *) * (max + 1));

	while (NLA_OK(nla, len)) {
		/*
		 * Mask off flag bits (NLA_F_NESTED, NLA_F_NET_BYTEORDER) so
		 * flagged attributes are indexed by their real type instead
		 * of being silently dropped as out of range.
		 */
		uint16_t type = nla->nla_type & NLA_TYPE_MASK;

		if (type <= max)
			tb[type] = nla;
		nla = NLA_NEXT(nla, len);
	}
}

#endif /* HAVE_NETLINK */
